{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return

    stride_am = {{stride("A", 0)}}
    stride_bn = {{stride("B", 1)}}

    start_pid = tl.program_id(axis=0).to(INDEX_DTYPE)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = num_pid_m * num_pid_n

    a_desc = triton.language.make_tensor_descriptor(
        base=A,
        shape=[M, K],
        strides=[stride_am, 1],
        block_shape=[BLOCK_M, BLOCK_K],
    )
    b_desc = triton.language.make_tensor_descriptor(
        base=B,
        shape=[N, K],
        strides=[stride_bn, 1],
        block_shape=[BLOCK_N, BLOCK_K],
    )

    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1

    tile_id = start_pid - NUM_SMS
    ki = -1

    pid_m = 0
    pid_n = 0
    offs_am = 0
    offs_bn = 0

    num_pid_in_group = GROUP_M * num_pid_n
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A)
    b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B)

    for _ in range(0, k_tiles * tiles_per_SM):
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            tile_id += NUM_SMS
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            offs_am = pid_m * BLOCK_M
            offs_bn = pid_n * BLOCK_N

        offs_k = ki * BLOCK_K

        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k])

        am_blocks = tl.cdiv(M, TILE_SIZE_A)
        ak_blocks = tl.cdiv(K, TILE_SIZE_A)
        bn_blocks = tl.cdiv(N, TILE_SIZE_B)
        bk_blocks = tl.cdiv(K, TILE_SIZE_B)

        {%- if SCALE_RECIPE_A == 5 %}  # ScalingType.Blockwise128x128
        scale_a_block = blockwise128x128_scaling(
            pid_m,
            a_scale,
            ki,
            am_blocks,
            ak_blocks,
            BLOCK_M,
            BLOCK_K,
            MIN_BLOCK_TILE_AM,
            MIN_BLOCK_TILE_AK,
        )
        {%- else %}  # ScalingType.Blockwise1xTILESIZE
        scale_a_block = blockwise1xTILESIZE_scaling(
            pid_m,
            a_scale,
            ki,
            M,
            am_blocks,
            ak_blocks,
            BLOCK_M,
            BLOCK_K,
            MIN_BLOCK_TILE_AK,
            TILE_SIZE_A,
        )
        {%- endif %}

        {%- if SCALE_RECIPE_A == 5 %}  # ScalingType.Blockwise128x128
        scale_b_block = blockwise128x128_scaling(
            pid_n,
            b_scale,
            ki,
            bn_blocks,
            bk_blocks,
            BLOCK_N,
            BLOCK_K,
            MIN_BLOCK_TILE_BN,
            MIN_BLOCK_TILE_BK,
        )
        {%- else %}  # ScalingType.Blockwise1xTILESIZE
        scale_b_block = blockwise1xTILESIZE_scaling(
            pid_n,
            b_scale,
            ki,
            N,
            bn_blocks,
            bk_blocks,
            BLOCK_N,
            BLOCK_K,
            MIN_BLOCK_TILE_BK,
            TILE_SIZE_B,
        )
        {%- endif %}

        a_scaled = a * scale_a_block
        b_scaled = b * scale_b_block
        accumulator = tl.dot(a_scaled, b_scaled.T, accumulator)

        if ki == k_tiles - 1:
            offs_cm = offs_am + tl.arange(0, BLOCK_M)
            offs_cn = offs_bn + tl.arange(0, BLOCK_N)

            # inductor generates a suffix
            {{store_output(
                ("offs_am", "offs_bn"),
                "accumulator",
                indent_width=12,
                val_shape=("BLOCK_M", "BLOCK_N"),
                block_indexing=True,
            )}}
            accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)


@triton.jit
def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr):
    """Return the scale operand in the form required by ``SCALE_RECIPE``.

    Recipe 0 (tensor-wise scaling) yields the value loaded from
    ``scale_ptr``; every other recipe yields ``scale_ptr`` itself so the
    caller can index into it later.
    """
    if SCALE_RECIPE != 0:
        # Non tensor-wise recipes: hand the pointer back untouched.
        return scale_ptr
    else:
        # Tensor-wise scaling: a single scalar value is loaded.
        return tl.load(scale_ptr)


@triton.jit
def blockwise1xTILESIZE_scaling(
    pid,
    scale,
    ki,
    lhs_size,
    lhs_blocks,
    k_blocks,
    BLOCK_lhs: tl.constexpr,
    BLOCK_K: tl.constexpr,
    MIN_BLOCK_TILE_K: tl.constexpr,
    TILE_SIZE: tl.constexpr,
):
    """Load and expand 1 x TILE_SIZE scales for one operand tile.

    Reads a (BLOCK_lhs, ceil(BLOCK_K / TILE_SIZE)) block of scales from
    ``scale``, addressed with ``k_blocks`` entries per row: one scale row per
    operand row and one scale column per K tile. Positions whose row is not
    below ``lhs_size`` or whose column is not below ``k_blocks`` are read as
    1.0. Each scale is then repeated MIN_BLOCK_TILE_K times along the last
    axis and the result is flattened to two dimensions.

    ``lhs_blocks`` is accepted but not used by this function.
    """
    # Scale coordinates covered by this (pid, ki) tile.
    rows = pid * BLOCK_lhs + tl.arange(0, BLOCK_lhs)
    cols = ki * tl.cdiv(BLOCK_K, TILE_SIZE) + tl.arange(0, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE)
    rows_2d = rows[:, None]
    cols_2d = cols[None, :]

    in_bounds = (rows_2d < lhs_size) & (cols_2d < k_blocks)
    scale_block = tl.load(
        scale + rows_2d * k_blocks + cols_2d,
        mask=in_bounds,
        other=1.0,
    )

    # Repeat every scale MIN_BLOCK_TILE_K times, then merge the two K axes.
    repeated = tl.broadcast_to(
        scale_block[:, :, None],
        (BLOCK_lhs, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE, MIN_BLOCK_TILE_K)
    )
    return repeated.reshape(
        BLOCK_lhs,
        ((BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) * MIN_BLOCK_TILE_K
    )


@triton.jit
def blockwise128x128_scaling(
    pid,
    scale,
    ki,
    lhs_blocks,
    k_blocks,
    BLOCK_lhs: tl.constexpr,
    BLOCK_K: tl.constexpr,
    MIN_BLOCK_TILE_lhs: tl.constexpr,
    MIN_BLOCK_TILE_K: tl.constexpr,
):
    """Load and expand 128 x 128 block scales for one operand tile.

    Reads a (ceil(BLOCK_lhs / 128), ceil(BLOCK_K / 128)) block of scales from
    ``scale``, addressed with ``k_blocks`` entries per row. Positions outside
    the (``lhs_blocks``, ``k_blocks``) grid are read as 1.0. Each scale is
    then repeated MIN_BLOCK_TILE_lhs times along the row axis and
    MIN_BLOCK_TILE_K times along the K axis, and the result is flattened to
    two dimensions.
    """
    # Scale-grid coordinates covered by this (pid, ki) tile.
    rows = pid * tl.cdiv(BLOCK_lhs, 128) + tl.arange(0, (BLOCK_lhs + 128 - 1) // 128)
    cols = ki * tl.cdiv(BLOCK_K, 128) + tl.arange(0, (BLOCK_K + 128 - 1) // 128)
    rows_2d = rows[:, None]
    cols_2d = cols[None, :]

    in_bounds = (rows_2d < lhs_blocks) & (cols_2d < k_blocks)
    scale_block = tl.load(
        scale + rows_2d * k_blocks + cols_2d,
        mask=in_bounds,
        other=1.0,
    )

    # Add one repeat axis per scale axis, broadcast, then collapse back to 2-D.
    repeated = tl.broadcast_to(
        scale_block[:, :, None, None],
        ((BLOCK_lhs + 128 - 1) // 128, (BLOCK_K + 128 - 1) // 128, MIN_BLOCK_TILE_lhs, MIN_BLOCK_TILE_K)
    )
    return repeated.reshape(
        ((BLOCK_lhs + 128 - 1) // 128) * MIN_BLOCK_TILE_lhs,
        ((BLOCK_K + 128 - 1) // 128) * MIN_BLOCK_TILE_K
    )
